package com.echatim.broker.localsvc;

import com.annotation.MethodFor;
import com.annotation.TopicFor;
import com.broker.base.protocol.ProtocolMessage;
import com.broker.base.protocol.request.RequestMessage;
import com.broker.base.protocol.response.Resp;
import com.broker.base.protocol.response.ResponseMessage;
import com.commom.AppRespError;
import com.echatim.ApplicationWrapper;
import com.echatim.filter.ProtocolHandlerInterceptor;
import com.echatim.filter.TopicInterceptor;
import com.exception.FormValidationException;
import com.exception.SocketGlobalExceptionAdvice;
import com.utils.Beans;
import com.utils.Pair;
import com.utils.ProtocolAnnotationUtils;
import com.utils.Streams;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.lang.NonNull;
import org.springframework.stereotype.Component;

import javax.validation.ConstraintViolation;
import javax.validation.Valid;
import javax.validation.ValidationException;
import javax.validation.Validator;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

/**
 * Forwards messages received from clients to the business-layer handler registered for them.
 *
 * <p>Handler beans are found via {@link TopicFor} (matched against the request topic) and
 * handler methods via {@link MethodFor} (matched against the request method). The request body
 * is deserialized into the handler parameter whose class equals {@link MethodFor#consumer()},
 * optionally validated, and the handler is invoked between the interceptors'
 * {@code preHandle}/{@code afterHandler} callbacks.
 *
 * @author kong <androidsimu@163.com>
 * create by 2019/2/21 13:18
 **/


@Slf4j
@ConditionalOnProperty(name="echatim.sdk.auth-type", havingValue="community")
@Component("topicDispatcher")
public class TopicDispatcher {
    /** The dispatch target resolved for a single request. */
    @Data
    private static class InvokeInfo{
        /** Bean annotated with {@link TopicFor} whose value matched the request topic. */
        private Object bean;
        /** Method annotated with {@link MethodFor} whose value matched the request method. */
        private Method method;
        private MethodFor  methodAnnotation;
        /**
         * One entry per handler parameter: the parameter instance and whether the request body
         * should be injected into it. Must be mutable; entries are replaced in place before the
         * handler is invoked.
         */
        private List<Pair<Object, Boolean>> methodParameterAndInject;

        /** @return true if the parameter list is known but no entry is flagged for injection */
        private boolean hasNoInjectTarget(){
            return methodParameterAndInject != null
                    && methodParameterAndInject.stream().noneMatch(Pair::getSecond);
        }

        /** @return true if no handler was resolved, or the handler has no injectable parameter */
        public boolean isEmpty(){
            return bean == null
                    || method == null
                    || methodAnnotation == null
                    || hasNoInjectTarget();
        }
    }


    /** @return the {@link SocketGlobalExceptionAdvice} bean from the application context */
    public SocketGlobalExceptionAdvice getGlobalExceptionAdvice() {
        return ApplicationWrapper.getContext().getBean(SocketGlobalExceptionAdvice.class);
    }

    /** @return the {@link Validator} bean from the application context */
    public Validator getValidator() {
        return ApplicationWrapper.getContext().getBean(Validator.class);
    }

    /** @return the {@link TopicInterceptor} bean from the application context */
    public ProtocolHandlerInterceptor getTopicInterceptor() {
        return ApplicationWrapper.getContext().getBean(TopicInterceptor.class);
    }


    /**
     * Routes a client request to its handler method and wraps the outcome in a response.
     *
     * @param requestMessage the incoming request; its topic and method select the handler
     * @return a failed response if no usable handler is registered; a response built by
     *         {@link SocketGlobalExceptionAdvice#onHandle} if validation or invocation throws;
     *         otherwise the response produced by the handler and the interceptors
     */
    public ResponseMessage consumeMessage(RequestMessage requestMessage){
        // NOTE(review): the full request body is logged at INFO; consider DEBUG if bodies may be sensitive.
        log.info("consumeMessage message:{}", Beans.json(requestMessage.getBody()));
        InvokeInfo invokeInfo = getTargetInvokeInfo(requestMessage);
        if(invokeInfo.isEmpty()){
            log.info("Topic:{}, Method: {} Couldn't invoke target method.", requestMessage.getTopic(), requestMessage.getMethod());
            // A handler was found, but none of its parameters has the class named by MethodFor.consumer().
            if(invokeInfo.hasNoInjectTarget() && invokeInfo.methodAnnotation != null){
                log.error("Except inject {}, but not found in parameter {}",
                        invokeInfo.methodAnnotation.consumer(),
                        invokeInfo.methodParameterAndInject.stream().map(v->v.getFirst().getClass()).collect(Collectors.toList()));
            }
            // Message text: "This API is only supported by the professional edition".
            return ResponseMessage.failed(requestMessage.getRequestId(), "该API仅专业版支持");
        }
        // Deserialize the body into the class of the parameter flagged for injection
        // (isEmpty() == false guarantees at least one flagged entry).
        Object parameterInject = Streams.find(invokeInfo.methodParameterAndInject, Pair::getSecond).get().getFirst();
        Class<?> injectClass = parameterInject.getClass();
        Object parameterInjectAfter = Beans.beans(Beans.json(requestMessage.getBody()), injectClass);
        if(parameterInjectAfter instanceof ProtocolMessage){
            ((ProtocolMessage)parameterInjectAfter).setClientId(requestMessage.getProtocolMessage().getClientId());
        }
        // Validate the form if the handler's @Valid parameter has this class.
        try {
            validForm(parameterInjectAfter, invokeInfo.method, injectClass);
        }
        catch (ValidationException e){
            return new ResponseMessage<>()
                    .setRequestId(requestMessage.getProtocolMessage().getRequestId())
                    .setResponse(getGlobalExceptionAdvice().onHandle(e));
        }
        // Inject the deserialized body into every parameter of that class.
        // A plain index loop avoids mutating the list while streaming over it.
        List<Pair<Object, Boolean>> parameters = invokeInfo.methodParameterAndInject;
        for(int i = 0; i < parameters.size(); i++){
            if(parameters.get(i).getFirst().getClass() == injectClass){
                parameters.set(i, new Pair<>(parameterInjectAfter, true));
            }
        }

        try {
            ResponseMessage responseMessage = ResponseMessage.succeed(requestMessage.getRequestId());
            // Run the handler between the interceptors' pre/after callbacks.
            filterApply(requestMessage, responseMessage, (request, response)->{
                long t1 = System.currentTimeMillis();
                // Invoke the target handler method reflectively.
                Object res = ProtocolAnnotationUtils.invokeMethod(invokeInfo.bean,
                        invokeInfo.method,
                        invokeInfo.methodParameterAndInject.stream().map(Pair::getFirst).toArray()
                );
                log.info(" ProtocolAnnotationUtils.invokeMethod consume:{} ms", System.currentTimeMillis()-t1);
                // Handlers must return a Resp. A null result is reported here too, without an NPE.
                if(!(res instanceof Resp)){
                    log.error("Topic:{}, Method: {} the return eventType expect:{} but is:{}",
                            request.getTopic(), request.getMethod(), Resp.class, res == null ? null : res.getClass());
                    response.setResponse(Resp.failed(AppRespError.SERVICE_ERROR).toJSON());
                    return;
                }
                response.setResponse(((Resp) res).toJSON());
            }, invokeInfo.method);

            return responseMessage;
        }
        catch (Exception e){
            // Boundary: any failure during interception/invocation becomes an error response.
            return new ResponseMessage<>()
                    .setRequestId(requestMessage.getProtocolMessage().getRequestId())
                    .setResponse(getGlobalExceptionAdvice().onHandle(e));
        }
    }


    /**
     * Resolves the bean and method registered for the message's topic and method.
     *
     * <p>If several candidates match, the last one visited wins. Candidates for which
     * {@code ProtocolAnnotationUtils.getMethodInstanceParam} returns null are skipped.
     *
     * @param message the incoming request
     * @return a never-null {@link InvokeInfo}; it is {@link InvokeInfo#isEmpty() empty} when
     *         nothing matched
     */
    @NonNull
    private InvokeInfo getTargetInvokeInfo(RequestMessage message){
        InvokeInfo invokeInfo = new InvokeInfo();

        ProtocolAnnotationUtils.getBeansWithAnnotation(TopicFor.class)
                .stream()
                .filter(v->v.getSecond().value().equals(message.getTopic())) // TopicFor.value == message.topic
                .forEach(beanAndAnnotation->{
                    ProtocolAnnotationUtils.getMethodWithAnnotation(MethodFor.class, beanAndAnnotation.getFirst().getClass())
                            .stream()
                            .filter(v->v.getSecond().value().equals(message.getMethod())) // MethodFor.value == message.method
                            .forEach(methodAndAnnotation->{
                        Method method = methodAndAnnotation.getFirst();
                        MethodFor methodFor = methodAndAnnotation.getSecond();

                        List<Object> params = ProtocolAnnotationUtils.getMethodInstanceParam(method);
                        if(params == null){
                            return;
                        }
                        invokeInfo.setBean(beanAndAnnotation.getFirst());
                        invokeInfo.setMethod(method);
                        // The filter above already ensures methodFor.value() equals message.getMethod().
                        invokeInfo.setMethodAnnotation(methodFor);
                        // Flag the parameters whose class is the declared consumer type. The list
                        // is an explicit ArrayList because consumeMessage replaces entries in place.
                        List<Pair<Object, Boolean>> parameterAndInject = new ArrayList<>(params.size());
                        for(Object param : params){
                            parameterAndInject.add(new Pair<>(param, param.getClass() == methodFor.consumer()));
                        }
                        invokeInfo.setMethodParameterAndInject(parameterAndInject);
                    });
                });

        return invokeInfo;
    }

    /**
     * Runs {@code handler} wrapped by every interceptor's {@code preHandle}/{@code afterHandler}.
     *
     * <p>If any {@code preHandle} returns false, the handler and all {@code afterHandler} callbacks
     * are skipped. A warning is logged if the vetoing interceptor left a successful response in
     * place (log text means "request was intercepted").
     *
     * @param requestMessage  the incoming request
     * @param responseMessage the response being built; interceptors and handler may modify it
     * @param handler         the actual handler invocation
     * @param invokeMethod    the target method, passed through to the interceptors
     * @throws Exception whatever an interceptor or the handler throws
     */
    private void filterApply(RequestMessage requestMessage, ResponseMessage responseMessage,
                             BiConsumer<RequestMessage, ResponseMessage> handler,
                             Object invokeMethod) throws Exception {
        // TODO: discover interceptors via Spring scanning.
        List<ProtocolHandlerInterceptor> protocolHandlerInterceptors = handlerInterceptors();
        // preHandle phase: any interceptor may veto the request.
        for(ProtocolHandlerInterceptor protocolHandlerInterceptor: protocolHandlerInterceptors){
            if(!protocolHandlerInterceptor.preHandle(requestMessage, responseMessage, invokeMethod)){
                Resp resp = Beans.beans(Beans.json(responseMessage.getResponse()), Resp.class);
                if(resp.ok()){
                    log.warn("请求被拦截");
                }
                return;
            }
        }
        handler.accept(requestMessage, responseMessage);
        // afterHandler phase.
        for(ProtocolHandlerInterceptor protocolHandlerInterceptor: protocolHandlerInterceptors){
            protocolHandlerInterceptor.afterHandler(requestMessage, responseMessage, invokeMethod);
        }
    }

    /**
     * Validates {@code requestForm} with the Bean Validation {@link Validator}, but only if the
     * method's {@link Valid}-annotated parameter (as reported by
     * {@code ProtocolAnnotationUtils.getParamClzWithAnnotationFor}) is of {@code targetParamClz}.
     *
     * @param requestForm    the deserialized request body
     * @param method         the handler method
     * @param targetParamClz the class the body was deserialized into
     * @throws ValidationException (a {@link FormValidationException}) if constraints are violated
     */
    private void validForm(Object requestForm, Method method, Class<?> targetParamClz) throws ValidationException{
        Class<?> paramClzWithValid = ProtocolAnnotationUtils.getParamClzWithAnnotationFor(method, Valid.class);
        if(paramClzWithValid == null || paramClzWithValid != targetParamClz){
            return;
        }
        Set<ConstraintViolation<Object>> violations = getValidator().validate(requestForm);
        if(!violations.isEmpty()){
            throw new FormValidationException(violations);
        }
    }


    /**
     * Returns the interceptors applied around every handler invocation, in order.
     * Subclasses may override this to change the chain.
     *
     * @return currently a single-element list containing {@link #getTopicInterceptor()}
     */
    protected List<ProtocolHandlerInterceptor> handlerInterceptors(){
        return Arrays.asList(getTopicInterceptor());
    }
}
